import json
import os
import re
import glob
import argparse
from collections import defaultdict

def extract_brand_and_year(text: str):
    """Extract a brand and a four-digit year from a text string using regex.

    The brand is the token (letters, digits, ``_`` or ``-``) that follows a
    ``brand:`` marker, matched case-insensitively and returned lower-cased.
    The year is the first 19xx/20xx number delimited by word boundaries.

    Args:
        text: String to search. ``None`` (for example a JSON ``null`` label
            field) is treated as empty, and any other non-string value is
            converted with ``str`` first; previously both made
            ``re.search`` raise ``TypeError``.

    Returns:
        A ``(brand, year)`` tuple of strings. Either element is ``"unknown"``
        when the corresponding pattern does not match.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    brand_match = re.search(r"brand\s*:\s*([A-Za-z0-9_\-]+)", text, flags=re.IGNORECASE)
    brand = brand_match.group(1).lower() if brand_match else "unknown"
    year_match = re.search(r"\b(19|20)\d{2}\b", text)
    year = year_match.group(0) if year_match else "unknown"
    return brand, year

# Prefix of the per-file report names written by main(). --input_dir and
# --output_dir default to the same directory, and a report name such as
# "evaluation_results_x_merged_y.json" also matches the "*_merged_*.json"
# input glob, so files carrying this prefix are excluded from the inputs to
# stop a re-run from evaluating its own reports.
_RESULTS_PREFIX = "evaluation_results_"


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Evaluate brand and time prediction accuracy from JSON files.")
    parser.add_argument(
        "--input_dir",
        type=str,
        default="twitter_qwen_results",
        help="Directory containing the merged JSON files to evaluate."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="twitter_qwen_results",
        help="Directory to save the detailed evaluation results."
    )
    return parser.parse_args()


def _label_text(value):
    """Return *value* as a string that is safe to regex-search ("" for None)."""
    # JSON null fields come back from dict.get as None, which re.search rejects.
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _evaluate_items(data):
    """Compare predicted brand/year with the ground truth for every record.

    Args:
        data: List loaded from one merged JSON file. Each record is read via
            the keys "ground_truth", "prediction_label" and "response".
            Entries that are not dicts are skipped and counted.

    Returns:
        ``(evaluation_results, counts)``. ``evaluation_results`` holds one
        detail dict per evaluated record. ``counts`` is a ``defaultdict(int)``
        with the keys "brand_correct", "brand_total", "time_correct",
        "time_total" and "skipped".
    """
    counts = defaultdict(int)
    evaluation_results = []

    for item in data:
        if not isinstance(item, dict):
            counts["skipped"] += 1
            continue

        gt_label = item.get("ground_truth", "unknown")
        pred_label = item.get("prediction_label", "unknown")
        reason = item.get("response", "")

        # Extract brand and year from ground truth and prediction
        gt_brand, gt_year = extract_brand_and_year(_label_text(gt_label))
        pred_brand, pred_year = extract_brand_and_year(_label_text(pred_label))

        # If the prediction label is null or lacks a field, fill only the
        # missing part from the free-text response.
        if pred_brand == "unknown" or pred_year == "unknown":
            r_brand, r_year = extract_brand_and_year(_label_text(reason))
            if pred_brand == "unknown":
                pred_brand = r_brand
            if pred_year == "unknown":
                pred_year = r_year

        # Records without a ground-truth brand/year are left out of that
        # metric's denominator; their per-record flag stays 0.
        is_brand_correct = 0
        if gt_brand != "unknown":
            counts["brand_total"] += 1
            if gt_brand == pred_brand:
                counts["brand_correct"] += 1
                is_brand_correct = 1

        is_time_correct = 0
        if gt_year != "unknown":
            counts["time_total"] += 1
            if gt_year == pred_year:
                counts["time_correct"] += 1
                is_time_correct = 1

        evaluation_results.append({
            "ground_truth": gt_label,
            "prediction_label": pred_label,
            "response": reason,
            "gt_brand": gt_brand,
            "gt_year": gt_year,
            "pred_brand": pred_brand,
            "pred_year": pred_year,
            "brand_correct": is_brand_correct,
            "time_correct": is_time_correct
        })

    return evaluation_results, counts


def _report_accuracy(label, correct, total, missing_message):
    """Print ``correct/total`` as a percentage, or *missing_message* if total is 0."""
    if total > 0:
        print(f"  {label} Accuracy: {correct / total:.2%} ({correct}/{total})")
    else:
        print(missing_message)


def main():
    """Evaluate every ``*_merged_*.json`` file in --input_dir.

    For each file, print brand and year accuracy and write a per-record report
    named ``evaluation_results_<input name>`` into --output_dir. Unreadable or
    malformed files are reported and skipped.
    """
    args = _parse_args()

    if not os.path.isdir(args.input_dir):
        print(f"Error: Input directory not found at '{args.input_dir}'")
        return

    # Skip our own reports (see _RESULTS_PREFIX); sort for a deterministic order.
    json_files = sorted(
        path
        for path in glob.glob(os.path.join(args.input_dir, "*_merged_*.json"))
        if not os.path.basename(path).startswith(_RESULTS_PREFIX)
    )
    if not json_files:
        print(f"No *_merged_*.json files found in '{args.input_dir}'")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    for json_path in json_files:
        print(f"Processing {json_path}...")
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
            # OSError covers FileNotFoundError as well as permission errors.
            print(f"  Error reading or parsing {json_path}: {e}")
            continue

        if not isinstance(data, list):
            print(f"  Error: expected a JSON array of records in {json_path}, got {type(data).__name__}")
            continue

        evaluation_results, counts = _evaluate_items(data)
        if counts["skipped"]:
            print(f"  Skipped {counts['skipped']} non-object record(s) in {json_path}")

        print("-" * 30)
        _report_accuracy(
            "Brand", counts["brand_correct"], counts["brand_total"],
            "  Brand accuracy could not be calculated (no ground truth brands found).",
        )
        _report_accuracy(
            "Time", counts["time_correct"], counts["time_total"],
            "  Time accuracy could not be calculated (no ground truth years found).",
        )
        print("-" * 30)

        results_filename = os.path.join(args.output_dir, f"{_RESULTS_PREFIX}{os.path.basename(json_path)}")
        with open(results_filename, 'w', encoding='utf-8') as f_out:
            json.dump(evaluation_results, f_out, indent=2)
        print(f"Detailed evaluation saved to {results_filename}\n")


if __name__ == "__main__":
    main()